# -*- coding: utf-8 -*-
import os
from apps.ssq.models import model
from dataset.ssq_dataset import LottoDataSet
from config.settings import SSQSettings

# Training script: build the dataset, fit the model, then save its weights
# under SSQSettings.CHECKPOINTS_PATH.

# Initialise the dataset.
# NOTE(review): train_data_rate=1 presumably assigns every sample to the
# training split -- confirm against LottoDataSet.
lotto_dataset = LottoDataSet(train_data_rate=1)

# Create the folder that will hold the saved weights. makedirs with
# exist_ok=True also creates any missing parent directories and does not fail
# if the directory already exists (or is created concurrently), unlike a
# separate exists() check followed by mkdir().
os.makedirs(SSQSettings.CHECKPOINTS_PATH, exist_ok=True)

# Start training.
model.fit(
    lotto_dataset.train_np_x,
    lotto_dataset.train_np_y,
    batch_size=SSQSettings.BATCH_SIZE,
    epochs=SSQSettings.EPOCHS,
)

# Save the model weights. The path string is kept exactly as before so any
# code that loads weights from this checkpoint prefix keeps working.
model.save_weights('{}/model_checkpoint_x'.format(SSQSettings.CHECKPOINTS_PATH))
